{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Phenolopy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load packages"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up a dask cluster"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "\n",
    "import os, sys\n",
    "import xarray as xr\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import datacube\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from scipy.signal import savgol_filter, wiener\n",
    "from scipy.stats import zscore\n",
    "from statsmodels.tsa.seasonal import STL as stl\n",
    "from datacube.drivers.netcdf import write_dataset_to_netcdf\n",
    "\n",
    "sys.path.append('../Scripts')\n",
    "from dea_datahandling import load_ard\n",
    "from dea_dask import create_local_dask_cluster\n",
    "from dea_plotting import display_map, rgb\n",
    "\n",
    "sys.path.append('./scripts')\n",
    "import phenolopy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialise the cluster. paste url into dask panel for more info.\n",
    "create_local_dask_cluster()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# open up a datacube connection\n",
    "dc = datacube.Datacube(app='phenolopy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Study area and data setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set study area and time range"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set lat, lon (y, x) dictionary of testing areas for gdv project\n",
    "loc_dict = {\n",
    "    'yan_full':   (-22.750, 119.10),\n",
    "    'yan_full_1': (-22.725, 119.05),\n",
    "    'yan_full_2': (-22.775, 119.15),\n",
    "    'roy_sign_1': (-22.618, 119.989),\n",
    "    'roy_full':   (-22.555, 120.01),\n",
    "    'roy_full_1': (-22.487, 119.927),\n",
    "    'roy_full_2': (-22.487, 120.092),\n",
    "    'roy_full_3': (-22.623, 119.927),\n",
    "    'roy_full_4': (-22.623, 120.092),\n",
    "    'oph_full':   (-23.280432, 119.859309),\n",
    "    'oph_full_1': (-23.375319, 119.859309),\n",
    "    'oph_full_2': (-23.185611, 119.859309),\n",
    "    'oph_full_3': (-23.233013, 119.859309),\n",
    "    'oph_full_4': (-23.280432, 119.859309),\n",
    "    'oph_full_5': (-23.327867, 119.859309),\n",
    "    'test':       (-31.6069288, 116.9426373)\n",
    "}\n",
    "\n",
    "# set buffer length and height (x, y)\n",
    "buf_dict = {\n",
    "    'yan_full':   (0.15, 0.075),\n",
    "    'yan_full_1': (0.09, 0.025),\n",
    "    'yan_full_2': (0.05, 0.0325),\n",
    "    'roy_sign_1': (0.15, 0.21),\n",
    "    'roy_full':   (0.33, 0.27),\n",
    "    'roy_full_1': (0.165209/2, 0.135079/2),\n",
    "    'roy_full_2': (0.165209/2, 0.135079/2),\n",
    "    'roy_full_3': (0.165209/2, 0.135079/2),\n",
    "    'roy_full_4': (0.165209/2, 0.135079/2),\n",
    "    'oph_full':   (0.08, 0.11863),\n",
    "    'oph_full_1': (0.08, 0.047452/2),\n",
    "    'oph_full_2': (0.08, 0.047452/2),\n",
    "    'oph_full_3': (0.08, 0.047452/2),\n",
    "    'oph_full_4': (0.08, 0.047452/2),\n",
    "    'oph_full_5': (0.08, 0.047452/2),\n",
    "    'test':       (0.05, 0.05)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# key of the study area to process (must exist in both loc_dict and buf_dict)\n",
    "study_area = 'roy_full_2'\n",
    "\n",
    "# unpack the (x, y) buffer pair for this area into lon and lat buffers\n",
    "lon_buff, lat_buff = buf_dict[study_area]\n",
    "\n",
    "# select time range. for a specific year, set same year with month 01 to 12. multiple years will be averaged.\n",
    "time_range = ('2016-11', '2018-02')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# look up the (lat, lon) centroid of the chosen study area\n",
    "lat, lon = loc_dict[study_area]\n",
    "\n",
    "# expand the centroid by the buffers to form the study boundary\n",
    "lat_extent = (lat - lat_buff, lat + lat_buff)\n",
    "lon_extent = (lon - lon_buff, lon + lon_buff)\n",
    "\n",
    "# display onto interactive map\n",
    "display_map(x=lon_extent, y=lat_extent)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load sentinel-2a, b data for above parameters\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set measurements (bands)\n",
    "measurements = [\n",
    "    'nbart_blue',\n",
    "    'nbart_green',\n",
    "    'nbart_red',\n",
    "    'nbart_nir_1',\n",
    "    'nbart_swir_2'\n",
    "]\n",
    "\n",
    "# create query from above and expected info\n",
    "query = {\n",
    "    'x': lon_extent,\n",
    "    'y': lat_extent,\n",
    "    'time': time_range,\n",
    "    'measurements': measurements,\n",
    "    'output_crs': 'EPSG:3577',\n",
    "    'resolution': (-10, 10),\n",
    "    'group_by': 'solar_day',\n",
    "}\n",
    "\n",
    "# load sentinel 2 data\n",
    "ds = load_ard(\n",
    "    dc=dc,\n",
    "    products=['s2a_ard_granule', 's2b_ard_granule'],\n",
    "    min_gooddata=0.90,\n",
    "    dask_chunks={'time': 1},\n",
    "    **query\n",
    ")\n",
    "\n",
    "# display dataset\n",
    "print(ds)\n",
    "\n",
    "# display a rgb data result of temporary resampled median \n",
    "#rgb(ds.resample(time='1M').median(), bands=['nbart_red', 'nbart_green', 'nbart_blue'], col='time', col_wrap=12)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Conform DEA band names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# takes our dask ds and conforms (renames) bands\n",
    "ds = phenolopy.conform_dea_band_names(ds)\n",
    "\n",
    "# display dataset\n",
    "print(ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculate vegetation index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# takes our dask ds and calculates veg index from spectral bands\n",
    "ds = phenolopy.calc_vege_index(ds, index='mavi', drop=True)\n",
    "\n",
    "# display dataset\n",
    "print(ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pre-processing phase"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Temporary - load MODIS dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#ds = phenolopy.load_test_dataset(data_path='./data/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# resample to monthly medians ('1M' interval)\n",
    "ds = phenolopy.resample(ds, interval='1M', reducer='median')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rechunk time into a single chunk, then interpolate missing values\n",
    "ds = ds.chunk({'time': -1})\n",
    "ds = phenolopy.interpolate(ds=ds, method='interpolate_na')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# keep only observations from 2017, dropping all other years\n",
    "ds = ds.where(ds['time.year'] == 2017, drop=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Group data by month and reduce by median"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# take our dask ds and group by month, reducing each group by median\n",
    "ds = phenolopy.group(ds, group_by='month', reducer='median')\n",
    "\n",
    "# display dataset\n",
    "print(ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute the dask-backed dataset so the result is loaded into memory\n",
    "ds = ds.compute()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Remove outliers from dataset on per-pixel basis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# chunk dask to -1 to make compatible with this function\n",
    "ds = ds.chunk({'time': -1})\n",
    "\n",
    "# takes our dask ds and remove outliers from data using median method\n",
    "ds = phenolopy.remove_outliers(ds=ds, method='median', user_factor=2, z_pval=0.05)\n",
    "\n",
    "# display dataset\n",
    "print(ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Resample dataset to weekly medians"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# takes our dask ds and resamples data to weekly medians ('1W' interval)\n",
    "ds = phenolopy.resample(ds, interval='1W', reducer='median')\n",
    "\n",
    "# display dataset\n",
    "print(ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Interpolate missing (i.e. nan) values linearly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# chunk dask to -1 to make compatible with this function\n",
    "ds = ds.chunk({'time': -1})\n",
    "\n",
    "# takes our dask ds and interpolates missing values\n",
    "ds = phenolopy.interpolate(ds=ds, method='interpolate_na')\n",
    "\n",
    "# display dataset\n",
    "print(ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Smooth data on per-pixel basis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# chunk dask to -1 to make compatible with this function\n",
    "ds = ds.chunk({'time': -1})\n",
    "\n",
    "# take our dask ds and smooth using a Savitzky-Golay filter\n",
    "ds = phenolopy.smooth(ds=ds, method='savitsky', window_length=3, polyorder=1)\n",
    "\n",
    "# display dataset\n",
    "print(ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Upper envelope correction\n",
    "todo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# todo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculate number of seasons "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# chunk dask to -1 to make compatible with this function\n",
    "ds = ds.chunk({'time': -1})\n",
    "\n",
    "# take our dask ds and calculate the number of seasons\n",
    "da_num_seasons = phenolopy.calc_num_seasons(ds=ds)\n",
    "\n",
    "# display dataset\n",
    "print(da_num_seasons)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate Phenolometrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute\n",
    "ds = ds.compute()\n",
    "print(ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%autoreload\n",
    "\n",
    "# calc phenometrics via phenolopy! one keyword per line so each setting is easy to tweak\n",
    "ds_phenos = phenolopy.calc_phenometrics(\n",
    "    da=ds['veg_index'],\n",
    "    peak_metric='pos',\n",
    "    base_metric='vos',\n",
    "    method='seasonal_amplitude',\n",
    "    factor=0.2,\n",
    "    thresh_sides='two_sided',\n",
    "    abs_value=0.1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set the metric you want to view\n",
    "metric_name = 'lios_values'\n",
    "\n",
    "# plot this on map\n",
    "ds_phenos[metric_name].plot(robust=True, cmap='Spectral')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export phenometrics to netcdf (write_dataset_to_netcdf is already imported in the first code cell)\n",
    "write_dataset_to_netcdf(ds_phenos, 'roy_2017_1w_phenos.nc')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set up params\n",
    "import random\n",
    "import shutil\n",
    "\n",
    "# set output filename (used as jpg prefix and zip name by run_test)\n",
    "# NOTE(review): name ends in 'f_015' while calc_phenometrics is run with factor=0.2 - confirm\n",
    "filename = 'roy_2_p_pos_b_vos_seas_amp_f_015'\n",
    "\n",
    "# set seed so the same pixels are sampled on every run\n",
    "random.seed(50)\n",
    "\n",
    "# number of pixels to sample, capped at the smaller grid dimension because\n",
    "# random.sample raises ValueError when asked for more items than exist\n",
    "n_pixels = min(200, len(ds_phenos['x']), len(ds_phenos['y']))\n",
    "\n",
    "# gen random x and y index lists (sampled without replacement) of n_pixels each\n",
    "x_list = random.sample(range(0, len(ds_phenos['x'])), n_pixels)\n",
    "y_list = random.sample(range(0, len(ds_phenos['y'])), n_pixels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_test(ds_raw, ds_phen, filename, x_list, y_list):\n",
    "    \"\"\"\n",
    "    Plot the veg index curve and phenometrics of sampled pixels, then zip the plots.\n",
    "\n",
    "    One jpg per (x, y) pair is written to the 'testing' folder. The folder is\n",
    "    then archived to '<filename>.zip' in the working directory and the files\n",
    "    inside it are removed.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ds_raw : dataset\n",
    "        Holds 'veg_index' and is indexed by x, y and time.\n",
    "    ds_phen : dataset\n",
    "        Holds the '*_values' and '*_times' metrics read below, indexed by x and y.\n",
    "    filename : str\n",
    "        Prefix of each jpg and base name of the zip archive.\n",
    "    x_list, y_list : sequences of int\n",
    "        Positional indices, paired element-wise.\n",
    "    \"\"\"\n",
    "\n",
    "    # make sure the output folder exists, savefig does not create it\n",
    "    os.makedirs('testing', exist_ok=True)\n",
    "\n",
    "    # loop through each pixel pair\n",
    "    for x, y in zip(x_list, y_list):\n",
    "\n",
    "        # get pixel and associated phenos pixel\n",
    "        v = ds_raw.isel(x=x, y=y)\n",
    "        p = ds_phen.isel(x=x, y=y)\n",
    "\n",
    "        # day of year is the shared x axis of every layer below\n",
    "        doy = v['time.dayofyear']\n",
    "\n",
    "        # create fig\n",
    "        fig = plt.figure(figsize=(12, 5))\n",
    "\n",
    "        # plot main trend\n",
    "        plt.plot(doy, v['veg_index'], linestyle='solid', marker='.', color='black')\n",
    "\n",
    "        # plot pos vals and times\n",
    "        plt.plot(p['pos_times'], p['pos_values'],\n",
    "                 marker='o', linestyle='', color='blue', label='POS')\n",
    "        plt.annotate('POS', (p['pos_times'], p['pos_values']))\n",
    "\n",
    "        # plot vos vals and times\n",
    "        plt.plot(p['vos_times'], p['vos_values'],\n",
    "                 marker='o', linestyle='', color='darkred', label='VOS')\n",
    "        plt.annotate('VOS', (p['vos_times'], p['vos_values']))\n",
    "\n",
    "        # plot bse vals\n",
    "        plt.axhline(p['bse_values'],\n",
    "                    marker='', linestyle='dashed', color='red', label='BSE')\n",
    "\n",
    "        # plot sos vals and times\n",
    "        plt.plot(p['sos_times'], p['sos_values'],\n",
    "                 marker='s', linestyle='', color='green', label='SOS')\n",
    "        plt.annotate('SOS', (p['sos_times'], p['sos_values']))\n",
    "\n",
    "        # plot eos vals and times\n",
    "        plt.plot(p['eos_times'], p['eos_values'],\n",
    "                 marker='s', linestyle='', color='orange', label='EOS')\n",
    "        plt.annotate('EOS', (p['eos_times'], p['eos_values']))\n",
    "\n",
    "        # plot aos vals (vertical line at the pos time)\n",
    "        plt.axvline(p['pos_times'],\n",
    "                    marker='', color='magenta', linestyle='dotted', label='AOS')\n",
    "\n",
    "        # plot los vals (horizontal line midway between sos and eos values)\n",
    "        plt.axhline((p['sos_values'] + p['eos_values']) / 2,\n",
    "                    marker='', color='yellowgreen', linestyle='dashdot', label='LOS')\n",
    "\n",
    "        # plot sios\n",
    "        plt.fill_between(doy, v['veg_index'], y2=p['bse_values'],\n",
    "                         color='red', alpha=0.1, label='SIOS')\n",
    "\n",
    "        # plot lios. mask is true only for dates between sos and eos times\n",
    "        t = ~v.where((doy >= p['sos_times']) & (doy <= p['eos_times'])).isnull()\n",
    "        plt.fill_between(doy, v['veg_index'], where=t['veg_index'],\n",
    "                         color='yellow', alpha=0.2, label='LIOS')\n",
    "\n",
    "        # plot siot\n",
    "        plt.fill_between(doy, v['veg_index'], y2=p['bse_values'],\n",
    "                         color='aqua', alpha=0.3, label='SIOT')\n",
    "\n",
    "        # plot liot\n",
    "        plt.fill_between(doy, v['veg_index'],\n",
    "                         color='aqua', alpha=0.1, label='LIOT')\n",
    "\n",
    "        # add legend\n",
    "        plt.legend(loc='best')\n",
    "\n",
    "        # create output filename\n",
    "        out = os.path.join('testing', filename + '_x_' + str(x) + '_y_' + str(y) + '.jpg')\n",
    "\n",
    "        # save to file without plotting, then close this exact figure to free memory\n",
    "        fig.savefig(out)\n",
    "        plt.close(fig)\n",
    "\n",
    "    # export as zip. make_archive appends '.zip' itself, so pass the bare name\n",
    "    shutil.make_archive(filename, 'zip', './testing')\n",
    "\n",
    "    # clear all files in dir\n",
    "    for root, dirs, files in os.walk('./testing'):\n",
    "        for file in files:\n",
    "            os.remove(os.path.join(root, file))\n",
    "\n",
    "# perform test\n",
    "run_test(ds_raw=ds, ds_phen=ds_phenos, filename=filename, x_list=x_list, y_list=y_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datacube.utils.cog import write_cog\n",
    "\n",
    "write_cog(geo_im=ds_phenos['lios_values'], fname='lios.tif', overwrite=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Working"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# different types of detection, using stl residuals - remove outlier method\n",
    "#from scipy.stats import median_absolute_deviation\n",
    "\n",
    "#v = ds.isel(x=0, y=0, time=slice(0, 69))\n",
    "#v['veg_index'].data = data\n",
    "\n",
    "#v_med = remove_outliers(v, method='median', user_factor=1, num_dates_per_year=24, z_pval=0.05)\n",
    "#v_zsc = remove_outliers(v, method='zscore', user_factor=1, num_dates_per_year=24, z_pval=0.1)\n",
    "\n",
    "#stl_res = stl(v['veg_index'], period=24, seasonal=5, robust=True).fit()\n",
    "#v_rsd = stl_res.resid\n",
    "#v_wgt = stl_res.weights\n",
    "\n",
    "#o = v.copy()\n",
    "#o['veg_index'].data = v_rsd\n",
    "\n",
    "#w = v.copy()\n",
    "#w['veg_index'].data = v_wgt\n",
    "\n",
    "#m = xr.where(o > o.std('time'), True, False)\n",
    "#o = v.where(m)\n",
    "\n",
    "#m = xr.where(w < 1e-8, True, False)\n",
    "#w = v.where(m)\n",
    "\n",
    "#fig = plt.figure(figsize=(18, 7))\n",
    "#plt.plot(v['time'], v['veg_index'], color='black', marker='o')\n",
    "#plt.plot(o['time'], o['veg_index'], color='red', marker='o', linestyle='-')\n",
    "#plt.plot(w['time'], w['veg_index'], color='blue', marker='o', linestyle='-')\n",
    "#plt.axhline(y=float(o['veg_index'].std('time')))\n",
    "#plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# working method for stl outlier detection. can't quite get it to match timesat results?\n",
    "# need to speed this up - very slow for even relatively small datasets\n",
    "#def func_stl(vec, period, seasonal, jump_l, jump_s, jump_t):\n",
    "    #resid = stl(vec, period=period, seasonal=seasonal, \n",
    "                #seasonal_jump=jump_s, trend_jump=jump_t, low_pass_jump=jump_l).fit()\n",
    "    #return resid.resid\n",
    "\n",
    "#def do_stl_apply(da, multi_pct, period, seasonal):\n",
    "    \n",
    "    # calc jump size for lowpass, season and trend to speed up processing\n",
    "    #jump_l = int(multi_pct * (period + 1))\n",
    "    #jump_s = int(multi_pct * (period + 1))\n",
    "    #jump_t = int(multi_pct * 1.5 * (period + 1))\n",
    "    \n",
    "    #f = xr.apply_ufunc(func_stl, da,\n",
    "                       #input_core_dims=[['time']], \n",
    "                       #output_core_dims=[['time']], \n",
    "                       #vectorize=True, dask='parallelized', \n",
    "                       #output_dtypes=[ds['veg_index'].dtype],\n",
    "                       #kwargs={'period': period, 'seasonal': seasonal, \n",
    "                               #'jump_l': jump_l, 'jump_s': jump_s, 'jump_t': jump_t}) \n",
    "    #return f\n",
    "\n",
    "# chunk up to make use of dask parallel\n",
    "#ds = ds.chunk({'time': -1})\n",
    "\n",
    "# calculate residuals for each vector  stl\n",
    "#stl_resids = do_stl_apply(ds['veg_index'], multi_pct=0.15, period=24, seasonal=13)\n",
    "\n",
    "#s = ds['veg_index'].stack(z=('x', 'y'))\n",
    "#s = s.chunk({'time': -1})\n",
    "#s = s.groupby('z').map(func_stl)\n",
    "#out = out.unstack()\n",
    "\n",
    "#s = ds.chunk({'time': -1})\n",
    "#t = xr.full_like(ds['veg_index'], np.nan)\n",
    "#out = xr.map_blocks(func_stl, ds['veg_index'], template=t).compute()\n",
    "\n",
    "#stl_resids = stl_resids.compute()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# working double logistic - messy though\n",
    "# https://colab.research.google.com/github/1mikegrn/pyGC/blob/master/colab/Asymmetric_GC_integration.ipynb#scrollTo=upaYKFdBGEAo\n",
    "# see for asym gaussian example\n",
    "\n",
    "#da = v.where(v['time.year'] == 2016, drop=True)\n",
    "\n",
    "#def logi(x, a, b, c, d):\n",
    "    #return a / (1 + xr.ufuncs.exp(-c * (x - d))) + b\n",
    "\n",
    "# get date at max veg index\n",
    "#idx = int(da['veg_index'].argmax())\n",
    "\n",
    "# get left and right of peak of season\n",
    "#da_l = da.where(da['time'] <= da['time'].isel(time=idx), drop=True)\n",
    "#da_r = da.where(da['time'] >= da['time'].isel(time=idx), drop=True)\n",
    "\n",
    "# must sort right curve (da_r) descending to flip data\n",
    "#da_r = da_r.sortby(da_r['time'], ascending=False)\n",
    "\n",
    "# get indexes of times (times not compat with exp)\n",
    "#da_l_x_idxs = np.arange(1, len(da_l['time']) + 1, step=1)\n",
    "#da_r_x_idxs = np.arange(1, len(da_r['time']) + 1, step=1)\n",
    "\n",
    "# fit curve\n",
    "#popt_l, pcov_l = curve_fit(logi, da_l_x_idxs, da_l['veg_index'], method=\"trf\")\n",
    "#popt_r, pcov_r = curve_fit(logi, da_r_x_idxs, da_r['veg_index'], method=\"trf\")\n",
    "\n",
    "# apply fit to original data\n",
    "#da_fit_l = logi(da_l_x_idxs, *popt_l)\n",
    "#da_fit_r = logi(da_r_x_idxs, *popt_r)\n",
    "\n",
    "# flip fitted vector back to original da order\n",
    "#da_fit_r = np.flip(da_fit_r)\n",
    "\n",
    "# get mean of pos value, remove overlap between l and r\n",
    "#pos_mean = (da_fit_l[-1] + da_fit_r[0]) / 2\n",
    "#da_fit_l = np.delete(da_fit_l, -1)\n",
    "#da_fit_r = np.delete(da_fit_r, 1)\n",
    "\n",
    "# concat back together with mean val inbetween\n",
    "#da_logi = np.concatenate([da_fit_l, pos_mean, da_fit_r], axis=None)\n",
    "\n",
    "# smooth final curve with mild savgol\n",
    "#da_logi = savgol_filter(da_logi, 3, 1)\n",
    "\n",
    "#fig = plt.subplots(1, 1, figsize=(6, 4))\n",
    "#plt.plot(da['time'], da['veg_index'], 'o')\n",
    "#plt.plot(da['time'], da_logi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "#from scipy.signal import find_peaks\n",
    "\n",
    "#x, y = 0, 1\n",
    "\n",
    "#v = da.isel(x=x, y=y)\n",
    "\n",
    "#height = float(v.quantile(dim='time', q=0.75))\n",
    "#distance = math.ceil(len(v['time']) / 4)\n",
    "\n",
    "#p = find_peaks(v, height=height, distance=distance)[0]\n",
    "\n",
    "#p_dts = v['time'].isel(time=p)\n",
    "\n",
    "#for p_dt in p_dts:\n",
    "    #plt.axvline(p_dt['time'].dt.dayofyear, color='black', linestyle='--')\n",
    "\n",
    "#count_peaks = len(num_peaks[0])\n",
    "#if count_peaks > 0:\n",
    "    #return count_peaks\n",
    "#else:\n",
    "    #return 0\n",
    "    \n",
    "#plt.plot(v['time.dayofyear'], v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# flip to get min closest to pos\n",
    "# if we want closest sos val to pos we flip instead to trick argmin\n",
    "#flip = dists_sos_v.sortby(dists_sos_v['time'], ascending=False)\n",
    "#min_right = flip.isel(time=flip.argmin('time'))\n",
    "#temp_pos_cls = da.isel(x=x, y=0).where(da['time'] == min_right['time'].isel(x=x, y=0))\n",
    "#plt.plot(temp_pos_cls.time, temp_pos_cls, marker='o', color='black', alpha=0.25)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
